{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "faa09879-9128-4864-8bb5-945ef9b8e84c",
   "metadata": {},
   "source": [
    "# RAG: Using Gemma LLM locally for question answering on private data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d047438b-6f18-47ed-aac9-12c741cefd06",
   "metadata": {},
   "source": [
    "In this notebook, we build a retrieval-augmented generation (RAG) system around [Google's Gemma](https://ai.google.dev/gemma) model. We generate vectors with [Elastic's ELSER](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html) model and store them in Elasticsearch, retrieve the most relevant passages with semantic search, and pass the top results to Gemma as its context window. Gemma itself is loaded in a local environment with the [Hugging Face transformers](https://huggingface.co/google/gemma-2b-it) library."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bd3acec-d490-4139-bab1-b874e1e7db8d",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef406b8a-03fb-49c5-baed-18e03bcd36d9",
   "metadata": {},
   "source": [
    "**Elastic Credentials** - Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to obtain your Elastic credentials (`ELASTIC_CLOUD_ID`, `ELASTIC_API_KEY`).\n",
    "\n",
    "**Hugging Face Token** - To use the [Gemma](https://huggingface.co/google/gemma-2b-it) model, you must accept its terms on Hugging Face and generate an [access token](https://huggingface.co/docs/hub/en/security-tokens) with the `write` role.\n",
    "\n",
    "**Gemma Model** - We use [gemma-2b-it](https://huggingface.co/google/gemma-2b-it). Google has released four open Gemma models, so you can substitute any of the others: [gemma-2b](https://huggingface.co/google/gemma-2b), [gemma-7b](https://huggingface.co/google/gemma-7b), or [gemma-7b-it](https://huggingface.co/google/gemma-7b-it)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac91d7a3-1198-4b11-a9c5-50028abc861b",
   "metadata": {},
   "source": [
    "## Install packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda41538-444c-48d7-80a0-b34b2e158b82",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q -U elasticsearch langchain transformers huggingface_hub torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15c2e924-e5a2-439b-8e98-f13a162db7fe",
   "metadata": {},
   "source": [
    "## Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7219411b-fae6-4c2a-b170-796bc30ed073",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from getpass import getpass\n",
    "from urllib.request import urlopen\n",
    "\n",
    "from elasticsearch import Elasticsearch, helpers\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import ElasticsearchStore\n",
    "from langchain import HuggingFacePipeline\n",
    "from langchain.chains import RetrievalQA\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.schema.output_parser import StrOutputParser\n",
    "from langchain.schema.runnable import RunnablePassthrough\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "182a413f-e7fd-4361-8096-90736d3df33e",
   "metadata": {},
   "source": [
    "## Get Credentials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b184b3a5-0cc8-43f9-b15d-f5ccf48f574b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ELASTIC_API_KEY = getpass(\"Elastic API Key :\")\n",
    "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID :\")\n",
    "elastic_index_name = \"gemma-rag\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2efbd81-70b9-409c-ab5f-796d538b42a1",
   "metadata": {},
   "source": [
    "## Add documents"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "161dfb9d-f11f-4de5-8489-6464ade0cdb2",
   "metadata": {},
   "source": [
    "### Download the sample dataset and deserialize the documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49427546-7b37-48f4-a6fe-395736ea2d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json\"\n",
    "\n",
    "# Download the sample dataset and parse the JSON payload\n",
    "with urlopen(url) as response:\n",
    "    workplace_docs = json.loads(response.read())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3bf0104-8b31-4b39-ad21-b372fd1fa0db",
   "metadata": {},
   "source": [
    "### Split Documents into Passages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79e55ed1-418e-48ed-b3e3-d28e10744eb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "metadata = []\n",
    "content = []\n",
    "\n",
    "for doc in workplace_docs:\n",
    "    content.append(doc[\"content\"])\n",
    "    metadata.append(\n",
    "        {\n",
    "            \"name\": doc[\"name\"],\n",
    "            \"summary\": doc[\"summary\"],\n",
    "            \"rolePermissions\": doc[\"rolePermissions\"],\n",
    "        }\n",
    "    )\n",
    "\n",
    "text_splitter = CharacterTextSplitter(chunk_size=50, chunk_overlap=0)\n",
    "docs = text_splitter.create_documents(content, metadatas=metadata)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4264bc1b-23b1-4547-a7f0-670944c3e605",
   "metadata": {},
   "source": [
    "## Index Documents into Elasticsearch using ELSER\n",
    "\n",
    "Before indexing, make sure you have [downloaded and deployed the ELSER model](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#download-deploy-elser) in your deployment and that it is running on an ML node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb1db78e-e40a-4a5c-9d15-75ee2a1d0994",
   "metadata": {},
   "outputs": [],
   "source": [
    "es = ElasticsearchStore.from_documents(\n",
    "    docs,\n",
    "    es_cloud_id=ELASTIC_CLOUD_ID,\n",
    "    es_api_key=ELASTIC_API_KEY,\n",
    "    index_name=elastic_index_name,\n",
    "    strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(\n",
    "        model_id=\".elser_model_2\"\n",
    "    ),\n",
    ")\n",
    "\n",
    "es"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02b1ead9-c442-40e9-ba81-d4d286ea878b",
   "metadata": {},
   "source": [
    "## Hugging Face login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f651e4-e760-4b59-a8a3-57c58dfc229f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7454f551-71a9-4310-bb2a-3fe0e683daab",
   "metadata": {},
   "source": [
    "## Load the model and tokenizer (`google/gemma-2b-it`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e1d98eb-0f4e-4c41-a851-125b75502963",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-2b-it\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-2b-it\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11a12596-2ac1-4101-b189-d21d53d33b04",
   "metadata": {},
   "source": [
    "## Create a `text-generation` pipeline and wrap it as a LangChain LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "623e74fb-5707-44f7-9dd8-d9499f7ab61e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    max_new_tokens=1024,\n",
    ")\n",
    "\n",
    "llm = HuggingFacePipeline(\n",
    "    pipeline=pipe,\n",
    "    model_kwargs={\"temperature\": 0.7},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49ce0e72-e419-4310-85e9-09077d6c40b2",
   "metadata": {},
   "source": [
    "## Format Docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c07a75-9220-4a82-a92e-3fc2727ad3ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_docs(docs):\n",
    "    return \"\\n\\n\".join(doc.page_content for doc in docs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6266222-6ec3-495a-8f14-460549bab89d",
   "metadata": {},
   "source": [
    "## Create a chain using Prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec203d1a-104b-4583-9ba1-a6b4b0354367",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = es.as_retriever(search_kwargs={\"k\": 5})\n",
    "\n",
    "template = \"\"\"Answer the question based only on the following context:\\n\n",
    "\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "\"\"\"\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "chain = (\n",
    "    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ae892dd-7442-4d4d-a804-1d717266e596",
   "metadata": {},
   "source": [
    "## Ask question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "ba312f17-44ae-423d-89a0-ea01eccd85b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Answer: The sales goals are to increase revenue, expand market share, and strengthen customer relationships in our target markets.'"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke(\"What are the sales goals?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
